package throttle

import (
	"errors"
	"fmt"
	"gitee.com/zhucheer/orange/app"
	"gitee.com/zhucheer/orange/utils"
	"github.com/juju/ratelimit"
	"net/http"
	"sync"
	"time"
)

// gcTime is added to Throttle.BreakTime to form the delay of each limiter
// entry's cleanup timer; when that timer fires the entry is removed.
const gcTime = 5 * time.Second

// rateUnite is the smallest time unit used for rate accounting: it is passed
// to the token bucket as its fill interval.
const rateUnite = 500 * time.Millisecond

// Throttle is a request rate-limiting middleware. It keeps one token bucket
// per limiter key and rejects requests for BreakTime once a bucket runs dry.
type Throttle struct {
	// MaxQps is used as the capacity of every token bucket and to derive the
	// number of tokens added per rateUnite.
	MaxQps    int64
	// IpSplit selects whether the client IP is included in the limiter key
	// (true) or only the request URL is used (false).
	IpSplit   bool
	// BreakTime is how long requests keep being rejected after a bucket has
	// been found empty.
	BreakTime time.Duration

	// requestMaps holds the live limiter entries, keyed by user tag.
	requestMaps map[string]*limitItem
	// mutex is held by addUserTag and clearUserTag while they modify requestMaps.
	mutex       sync.Mutex
}

// limitItem is the per-key throttling state stored in Throttle.requestMaps.
type limitItem struct {
	// UserTag is the key under which this item is stored in the map.
	UserTag       string
	// BucketHandler is the token bucket one token is taken from per request.
	BucketHandler *ratelimit.Bucket
	// DelaySecond is the cleanup timer; it is reset on each request and its
	// firing triggers removal of the item.
	DelaySecond   *time.Timer
	// BreakExpireAt is the instant until which requests for this key are
	// rejected without consulting the bucket.
	BreakExpireAt time.Time
}

// NewThrottle creates a rate-limiting middleware.
//
// maxQps is the maximum number of requests per second, breakTime is how long
// a key stays blocked once it exceeds the limit, and ipSplit selects whether
// the client IP is part of the limiter key. The returned Throttle starts with
// an empty, ready-to-use entry map.
func NewThrottle(maxQps int64, breakTime time.Duration, ipSplit bool) *Throttle {
	t := new(Throttle)
	t.requestMaps = map[string]*limitItem{}
	t.MaxQps = maxQps
	t.BreakTime = breakTime
	t.IpSplit = ipSplit
	return t
}

// Func implements the Middleware interface and returns the throttling
// middleware.
//
// For each request the returned handler looks up the limiter entry for the
// request URL (and client IP when IpSplit is set). If the entry is still
// inside its break window the request is rejected through showBreakErr.
// Otherwise the entry's cleanup timer is pushed back by gcTime+BreakTime and
// one token is taken from its bucket; when no token is available a new break
// window of BreakTime starts and the request is rejected, else next is called.
//
// The receiver is a value, so the returned middleware operates on the copy of
// Throttle made by this call.
func (w Throttle) Func() app.MiddlewareFunc {
	// watched records the entries that already have a cleanup goroutine.
	// Starting one goroutine per request would leave all but one of them
	// blocked forever, since a timer delivers a single value per firing.
	var watched sync.Map

	return func(next app.HandlerFunc) app.HandlerFunc {
		return func(c *app.Context) error {
			item := w.getLimter(w.IpSplit, c.OrangeInput.URL(), c.OrangeInput.IP())

			// BreakExpireAt is shared by all requests for the same key, so
			// access it only while holding the mutex.
			w.mutex.Lock()
			blocked := item.BreakExpireAt.After(time.Now())
			w.mutex.Unlock()
			if blocked {
				return showBreakErr(c)
			}

			item.DelaySecond.Reset(gcTime + w.BreakTime)
			if _, running := watched.LoadOrStore(item, struct{}{}); !running {
				go func(it *limitItem) {
					// Fires once the key has been idle for gcTime+BreakTime.
					<-it.DelaySecond.C
					watched.Delete(it)
					w.clearUserTag(it.UserTag)
				}(item)
			}

			if item.BucketHandler.TakeAvailable(1) < 1 {
				w.mutex.Lock()
				item.BreakExpireAt = time.Now().Add(w.BreakTime)
				w.mutex.Unlock()
				return showBreakErr(c)
			}

			return next(c)
		}
	}
}

// showBreakErr reports a 429 Too Many Requests condition: it passes the
// standard status text and code to c.HttpError and returns an error carrying
// the same text.
func showBreakErr(c *app.Context) error {
	const code = http.StatusTooManyRequests
	msg := http.StatusText(code)
	c.HttpError(msg, code)
	return errors.New(msg)
}

// getLimter returns the limiter entry for a request, creating it through
// addUserTag when none exists yet.
//
// The key is built from reqUrl, plus ipAddr when ipSplit is true, and then
// passed through utils.ShortTag (presumably to shorten/hash it — confirm
// against the utils package).
func (w *Throttle) getLimter(ipSplit bool, reqUrl, ipAddr string) *limitItem {
	userTag := fmt.Sprintf("orangeThrottle_%s", reqUrl)
	if ipSplit {
		userTag = fmt.Sprintf("orangeThrottle_%s_%s", reqUrl, ipAddr)
	}
	userTag = utils.ShortTag(userTag, 1)

	// addUserTag and clearUserTag modify the map from other goroutines, so
	// the lookup must hold the mutex too. It is released before calling
	// addUserTag, which acquires the mutex itself.
	w.mutex.Lock()
	limiter, exists := w.requestMaps[userTag]
	w.mutex.Unlock()
	if !exists {
		return w.addUserTag(userTag)
	}

	return limiter
}

// addUserTag registers a limiter entry for userTag and returns it. If an
// entry for userTag already exists, that entry is returned unchanged.
//
// A new entry gets a token bucket with capacity MaxQps that is refilled every
// rateUnite, a cleanup timer of gcTime+BreakTime, and a break window that has
// already expired.
func (w *Throttle) addUserTag(userTag string) *limitItem {
	w.mutex.Lock()
	defer w.mutex.Unlock()
	if w.requestMaps == nil {
		w.requestMaps = make(map[string]*limitItem)
	}

	// Another request may have registered the same tag after the caller's
	// lookup; replacing that entry would orphan its bucket and timer.
	if existing, ok := w.requestMaps[userTag]; ok {
		return existing
	}

	// ratelimit panics on a non-positive capacity or quantum, so a MaxQps
	// below 1 is treated as 1, the strictest valid limit.
	capacity := w.MaxQps
	if capacity < 1 {
		capacity = 1
	}

	var bucket *ratelimit.Bucket
	if quantum := capacity / int64(time.Second/rateUnite); quantum >= 1 {
		// Integer division: the per-interval refill rounds down.
		bucket = ratelimit.NewBucketWithQuantum(rateUnite, capacity, quantum)
	} else {
		// Less than one token per rateUnite: add a single token every
		// 1/capacity seconds instead.
		bucket = ratelimit.NewBucket(time.Second/time.Duration(capacity), capacity)
	}

	item := &limitItem{
		UserTag:       userTag,
		BucketHandler: bucket,
		DelaySecond:   time.NewTimer(gcTime + w.BreakTime),
		BreakExpireAt: time.Now(),
	}
	w.requestMaps[userTag] = item
	return item
}

// clearUserTag removes the limiter entry stored under userTag. When that
// leaves the map empty, the map itself is dropped; addUserTag allocates a
// new one on demand.
func (w *Throttle) clearUserTag(userTag string) {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	if delete(w.requestMaps, userTag); len(w.requestMaps) > 0 {
		return
	}
	w.requestMaps = nil
}
